""" discover and run doctests in Cython extension modules."""

import os
import pathlib
import re
import sysconfig

import pytest

from typing import Any, Iterable, Union

from _pytest.nodes import Collector
from _pytest.doctest import skip, DoctestModule, DoctestItem
from _pytest.pathlib import resolve_package_path, ImportMode


# Source-file suffixes whose files may have a compiled extension module next
# to them (see pytest_collect_file).
CYTHON_SUFFIXES = ['.py', '.pyx']
# Interpreter-specific filename suffix of compiled extension modules, as
# reported by sysconfig (sysconfig returns None for an unknown variable).
EXT_SUFFIX = sysconfig.get_config_var("EXT_SUFFIX")

# Environment variable that, when set to "1", makes pytest skip its
# "import file mismatch" check; this plugin sets it temporarily while
# importing a compiled module during collection.
IGNORE_IMPORTMISMATCH_KEY = "PY_IGNORE_IMPORTMISMATCH"
# Value of that variable when this module was imported ("" if unset). It is
# read once here, so later changes to os.environ are not reflected in it.
IGNORE_IMPORTMISMATCH = os.environ.get(IGNORE_IMPORTMISMATCH_KEY, "")


def pytest_addoption(parser: pytest.Parser):
    """Register the ``--doctest-cython`` flag in the "cython" option group.

    The flag is a boolean (``store_true``, default ``False``) stored on the
    config as ``doctest_cython``.
    """
    option_settings = dict(
        dest="doctest_cython",
        action="store_true",
        default=False,
        help="run doctests in all .so and .pyd modules",
    )
    cython_options = parser.getgroup("cython")
    cython_options.addoption("--doctest-cython", **option_settings)


def pytest_collect_file(
    file_path: pathlib.Path, parent: pytest.Collector
) -> Union[pytest.Module, None]:
    """Collect doctests for a Cython source file that has a compiled sibling.

    Only active when ``--doctest-cython`` is given. For a ``.py``/``.pyx``
    file, the compiled module is looked up next to it by replacing the file's
    suffix with the interpreter's ``EXT_SUFFIX``. If that file exists, a
    ``_PatchedDoctestModule`` for *file_path* is returned; otherwise ``None``
    (meaning "not collected by this plugin").

    The legacy ``path`` (``py.path.local``) hook argument is deliberately not
    requested: it was unused here and pytest deprecated it in 7.0 in favour
    of ``file_path``. pluggy passes hook arguments by name, so omitting it is
    transparent to pytest.
    """
    if file_path.suffix not in CYTHON_SUFFIXES:
        return None
    if not parent.config.getoption('--doctest-cython'):
        return None

    # sysconfig yields None for an unknown config variable; without a real
    # extension suffix there is no compiled module to look for (an empty
    # suffix would make with_suffix() strip the extension instead).
    if not EXT_SUFFIX:
        return None

    # only run tests if both the source file and the matching compiled module exist
    bin_path = file_path.with_suffix(EXT_SUFFIX)
    if not bin_path.exists():
        return None

    return _PatchedDoctestModule.from_parent(parent, path=file_path)


class _PatchedDoctestModule(DoctestModule):
    """Doctest collector for a Cython source whose compiled module gets imported.

    The collected node's path is the ``.pyx``/``.py`` source file, while the
    module that is actually imported is the compiled extension next to it, so
    pytest's import-file-mismatch check is expected to fail. That check is
    disabled (via ``PY_IGNORE_IMPORTMISMATCH``) while importing, and a looser
    check that ignores file suffixes (``_check_module_import``) is applied
    afterwards.
    """

    def collect(self) -> Iterable[DoctestItem]:
        """Import the module, collect its doctests and verify the import path.

        Raises ``Collector.CollectError`` if the imported module does not
        correspond to ``self.path``, unless the ``doctest_ignore_import_errors``
        option is set, in which case the module is skipped instead.
        """
        mode = ImportMode(self.config.getoption("importmode"))

        # We know we will get an import file mismatch error, so ignore it for
        # now; the import paths are double-checked after the import.
        patch_env = mode is not ImportMode.importlib and IGNORE_IMPORTMISMATCH != "1"
        # Exact current value (None when unset) so it can be restored faithfully.
        previous = os.environ.get(IGNORE_IMPORTMISMATCH_KEY)
        if patch_env:
            os.environ[IGNORE_IMPORTMISMATCH_KEY] = "1"

        try:
            items = list(super().collect())  # import the module and collect tests
            # Resolve the module while the check is still disabled: depending on
            # the pytest version, ``self.obj`` may (re-)run pytest's own import
            # logic, which would otherwise raise the mismatch error we suppress.
            module = self.obj
        finally:
            if patch_env:
                # Put the environment back exactly as it was: remove the
                # variable if it was unset rather than leaving "" behind.
                if previous is None:
                    os.environ.pop(IGNORE_IMPORTMISMATCH_KEY, None)
                else:
                    os.environ[IGNORE_IMPORTMISMATCH_KEY] = previous

        try:
            _check_module_import(module, self.path, mode)
        except Collector.CollectError:
            if self.config.getvalue("doctest_ignore_import_errors"):
                skip("unable to import module %r" % self.path)
            else:
                raise

        return _add_line_numbers(module, items)


def _without_suffixes(path: Union[str, pathlib.Path]) -> pathlib.Path:
    """Return *path* with every suffix removed from its final component.

    Everything from the first ``.`` of the file name onwards is dropped, so
    ``pkg/mod.cpython-311-x86_64-linux-gnu.so`` and ``pkg/mod.pyx`` both
    become ``pkg/mod``. This lets a compiled extension be matched against its
    source file.
    """
    as_path = pathlib.Path(path)
    bare_name, _, _ = as_path.name.partition('.')
    # bare_name contains no '.', so the result carries no suffix at all.
    return as_path.with_name(bare_name)


def _get_module_name(path: pathlib.Path) -> str:
    """Return the dotted module name for the source file *path*.

    If ``resolve_package_path`` finds a package directory for *path*, the name
    is built from *path* (without its suffix) relative to that directory's
    parent, with a trailing ``__init__`` component dropped. Otherwise the
    name is simply the file stem.
    """
    pkg_path = resolve_package_path(path)
    if pkg_path is None:
        return path.stem

    names = list(path.with_suffix("").relative_to(pkg_path.parent).parts)
    if names[-1] == "__init__":
        # a package's __init__ file is named after the package itself
        names.pop()
    return ".".join(names)


def _check_module_import(module: Any, path: pathlib.Path, mode: ImportMode) -> None:
    """Verify that *module* was really imported from *path*, ignoring suffixes.

    The compiled extension (e.g. ``mod.<EXT_SUFFIX>``) and its source
    (``mod.pyx``) are considered the same file when they differ only in their
    suffixes. The check is skipped in ``importlib`` import mode, or when
    ``PY_IGNORE_IMPORTMISMATCH`` was ``"1"`` at plugin import time.

    Raises:
        Collector.CollectError: if the module has no ``__file__`` or it points
            somewhere other than *path*.
    """
    if mode is ImportMode.importlib or IGNORE_IMPORTMISMATCH == "1":
        return

    module_file = getattr(module, "__file__", None)
    # A module without a file location cannot be verified; treat it as a mismatch.
    if module_file is not None and _without_suffixes(module_file) == _without_suffixes(path):
        return

    # Only needed for the error message, so computed on the failure path only.
    module_name = _get_module_name(path)
    raise Collector.CollectError(
        "import file mismatch:\n"
        "imported module %r has this __file__ attribute:\n"
        "  %s\n"
        "which is not the same as the test file we want to collect:\n"
        "  %s\n"
        "HINT: remove __pycache__ / .pyc files and/or use a "
        "unique basename for your test file modules" % (module_name, module_file, path)
    )


def _add_line_numbers(module: Any, items: Iterable[DoctestItem]) -> Iterable[DoctestItem]:
    """Fold Cython ``__test__`` doctests into their docstring counterparts.

    Cython's ``autotestdict`` directive generates a module-level ``__test__``
    dict, so tests can appear under names starting with
    ``<module>.__test__.`` (optionally followed by a ``(line N)`` marker) in
    addition to the plain ``<module>.<name>`` tests. For each such item:

    * ``N`` is parsed from the ``(line N)`` marker (``None`` if there is none);
    * the counterpart name is the part before the first whitespace with
      ``<module>.__test__`` replaced by ``<module>``;
    * if the counterpart exists and has examples, the ``__test__`` item is
      dropped and ``N`` is stored as the counterpart's ``dtest.lineno``;
    * otherwise the ``__test__`` item is kept, renamed to the counterpart name
      (e.g. when docstrings were stripped at compile time), and gets ``N`` as
      its ``dtest.lineno``.

    The surviving items are yielded sorted by their (possibly updated) name.
    This is a generator, so nothing happens until the result is iterated.
    """
    # NOTE(review): only ``.name`` is reassigned on renamed items; confirm
    # whether any node id derived from the old name also needs updating.
    mod_name = module.__name__
    dict_prefix = mod_name + '.__test__'
    line_marker = re.compile(r'\(line (\d+)\)')

    by_name = {}
    for item in items:
        by_name[item.name] = item

    autotest_keys = [key for key in by_name if key.startswith(dict_prefix + '.')]
    for key in autotest_keys:
        found = line_marker.search(key)
        lineno = None if found is None else int(found.group(1))
        target = key.split()[0].replace(dict_prefix, mod_name)

        if target in by_name and by_name[target].dtest.examples:
            # Counterpart exists: keep it and hand it the line number.
            del by_name[key]
            by_name[target].dtest.lineno = lineno
        else:
            # No usable counterpart: keep this test but hide its __test__ origin.
            survivor = by_name[key]
            survivor.name = target
            survivor.dtest.lineno = lineno

    yield from sorted(by_name.values(), key=lambda entry: entry.name)
